# Model definition: the class to build and the config mapping handed to it.
# NOTE(review): class_name values look like `<module path>><class name>`;
# confirm the exact format against the config loader.
model_config:
  class_name: tensorflow_asr.models.transducer.conformer>Conformer
  config:
    name: conformer

    # Audio front-end / feature extraction settings.
    speech_config:
      sample_rate: 16000
      # Assuming the *_ms keys are milliseconds: 25 ms / 10 ms at 16 kHz is a
      # 400-sample window with a 160-sample hop.
      frame_ms: 25
      stride_ms: 10
      nfft: 512
      num_feature_bins: 80
      feature_type: log_mel_spectrogram
      augmentation_config:
        feature_augment:
          time_masking:
            prob: 1.0
            num_masks: 10
            # NOTE(review): -1 presumably means "no fixed mask width", leaving
            # p_upperbound as the limit; verify against the augmentation code.
            mask_factor: -1
            p_upperbound: 0.05
            mask_value: 0
          freq_masking:
            prob: 1.0
            num_masks: 1
            mask_factor: 27
            mask_value: 0

    # Subsampling layer placed in front of the encoder blocks; the lists have
    # one entry per convolution (two here).
    encoder_subsampling:
      class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling
      config:
        filters: [144, 144]  # same value as encoder_dmodel below
        kernels: [3, 3]
        strides: [2, 2]
        paddings: ["causal", "causal"]
        norms: ["batch", "batch"]
        activations: ["swish", "swish"]

    # Encoder hyper-parameters.
    encoder_ffm_residual_factor: 0.5
    encoder_mhsam_residual_factor: 1.0
    encoder_convm_residual_factor: 1.0
    encoder_dmodel: 144
    encoder_num_blocks: 16
    encoder_head_size: 36  # == dmodel // num_heads (144 // 4)
    encoder_num_heads: 4
    encoder_mha_type: relmha
    encoder_interleave_relpe: true
    encoder_use_attention_causal_mask: false
    encoder_use_attention_auto_mask: true
    encoder_mhsam_use_attention_bias: false
    encoder_kernel_size: 31
    encoder_dropout: 0.1
    encoder_padding: causal
    encoder_memory_length: null

    # Prediction network hyper-parameters.
    prediction_label_encode_mode: embedding
    prediction_embed_dim: 320
    prediction_num_rnns: 1
    prediction_rnn_units: 320
    prediction_rnn_type: lstm
    prediction_rnn_implementation: 2
    prediction_rnn_unroll: false
    prediction_layer_norm: true
    prediction_projection_units: 0

    # Joint network hyper-parameters.
    joint_dim: 320
    prejoint_encoder_linear: true
    prejoint_prediction_linear: true
    postjoint_linear: false
    joint_activation: tanh
    joint_mode: add

    blank: 0
    # Template placeholder: this file has to be rendered before it is parsed as
    # YAML. Left unquoted so the rendered value keeps its native (numeric) type.
    vocab_size: {{decoder_config.vocabsize}}

    kernel_regularizer:
      class_name: l2
      config:
        # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. PyYAML)
        # resolve a bare `1e-6` as a string, not a float.
        l2: 1.0e-6

# Training setup: optimizer, batching, epoch count and callbacks.
learning_config:
  optimizer_config:
    class_name: Adam
    config:
      learning_rate:
        class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule
        config:
          dmodel: 144  # same value as model_config.config.encoder_dmodel
          warmup_steps: 10000
          # Arithmetic expression, not a number: YAML does not evaluate it, so
          # the value is a string. Quoted to make that explicit.
          # NOTE(review): presumably evaluated by the schedule class; confirm.
          max_lr: '0.05/(144**0.5)'
          min_lr: null
          scale: 2.0
      beta_1: 0.9
      beta_2: 0.98
      # Exponent floats carry an explicit mantissa dot: YAML 1.1 loaders
      # (e.g. PyYAML) resolve a bare `1e-9` as a string, not a float.
      epsilon: 1.0e-9
      weight_decay: 1.0e-6

  gwn_config: null

  gradn_config: null

  # NOTE(review): if ga_steps is the gradient-accumulation step count, the
  # effective batch is 2 * 16 = 32; verify against the trainer.
  batch_size: 2
  ga_steps: 16
  num_epochs: 300

  # Each entry names a callback class and the config mapping handed to it.
  callbacks:
    - class_name: tensorflow_asr.callbacks>TerminateOnNaN
      config: {}
    - class_name: tensorflow_asr.callbacks>ModelCheckpoint
      config:
        # `{{modeldir}}` is a template placeholder filled in at render time;
        # `{epoch:02d}` uses single braces and is left literal by the template
        # (presumably formatted per epoch by the checkpoint callback).
        filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5
        save_best_only: false
        save_weights_only: true
        save_freq: epoch
    - class_name: tensorflow_asr.callbacks>TensorBoard
      config:
        log_dir: {{modeldir}}/tensorboard
        histogram_freq: 0
        write_graph: false
        write_images: false
        write_steps_per_second: false
        update_freq: batch
        profile_batch: 0
    - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore
      config:
        # Placeholders left unquoted so an empty rendering still parses as null
        # rather than as an empty string.
        model_handle: {{kaggle_model_handle}}
        model_dir: {{modeldir}}
        save_freq: epoch